{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ShuffleNet图像分类\n",
    "\n",
    "## ShuffleNet网络介绍\n",
    "\n",
    "ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型，和MobileNet, SqueezeNet等一样主要应用在移动端，所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作：Pointwise Group Convolution和Channel Shuffle，这在保持精度的同时大大降低了模型的计算量。因此，ShuffleNetV1和MobileNet类似，都是通过设计更高效的网络结构来实现模型的压缩和加速。\n",
    "\n",
    "\n",
    "如下图所示，ShuffleNet在保持不低的准确率的前提下，将参数量几乎降低到了最小，因此其运算速度较快，单位参数量对模型准确率的贡献非常高。\n",
    "\n",
    "![shufflenet1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/application/source_zh_cn/cv/images/shufflenet_1.png)\n",
    "\n",
    "> 图片来源：Bianco S, Cadene R, Celona L, et al. Benchmark analysis of representative deep neural network architectures[J]. IEEE access, 2018, 6: 64270-64277.\n",
    "\n",
    "## 模型架构\n",
    "\n",
    "ShuffleNet最显著的特点在于对不同通道进行重排来解决Group Convolution带来的弊端。通过对ResNet的Bottleneck单元进行改进，在较小的计算量的情况下达到了较高的准确率。\n",
    "\n",
    "### Pointwise Group Convolution\n",
    "\n",
    "Group Convolution（分组卷积）原理如下图所示，相比于普通的卷积操作，分组卷积的情况下，每一组的卷积核大小为in_channels/g\\*k\\*k，一共有g组，所有组共有(in_channels/g\\*k\\*k)\\*out_channels个参数，是正常卷积参数的1/g。分组卷积中，每个卷积核只处理输入特征图的一部分通道，**其优点在于参数量会有所降低，但输出通道数仍等于卷积核的数量**。\n",
    "\n",
    "![shufflenet2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/application/source_zh_cn/cv/images/shufflenet_2.png)\n",
    "\n",
    "> 图片来源：Huang G, Liu S, Van der Maaten L, et al. Condensenet: An efficient densenet using learned group convolutions[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 2752-2761.\n",
    "\n",
    "Depthwise Convolution（深度可分离卷积）将组数g分为和输入通道相等的`in_channels`，然后对每一个`in_channels`做卷积操作，每个卷积核只处理一个通道，记卷积核大小为1\\*k\\*k，则卷积核参数量为：in_channels\\*k\\*k，**得到的feature maps通道数与输入通道数相等**；\n",
    "\n",
    "Pointwise Group Convolution（逐点分组卷积）在分组卷积的基础上，令**每一组的卷积核大小为** $1\\times 1$，卷积核参数量为(in_channels/g\\*1\\*1)\\*out_channels。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/miniconda3/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/usr/local/miniconda3/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/usr/local/miniconda3/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/usr/local/miniconda3/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n"
     ]
    }
   ],
   "source": [
    "from mindspore import nn\n",
    "import mindspore.ops as ops\n",
    "from mindspore import Tensor\n",
    "\n",
    "class GroupConv(nn.Cell):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size,\n",
    "                 stride, pad_mode=\"pad\", pad=0, groups=1, has_bias=False):\n",
    "        super(GroupConv, self).__init__()\n",
    "        self.groups = groups\n",
    "        self.convs = nn.CellList()\n",
    "        for _ in range(groups):\n",
    "            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,\n",
    "                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,\n",
    "                                        padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))\n",
    "\n",
    "    def construct(self, x):\n",
    "        features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)\n",
    "        outputs = ()\n",
    "        for i in range(self.groups):\n",
    "            outputs = outputs + (self.convs[i](features[i].astype(\"float32\")),)\n",
    "        out = ops.cat(outputs, axis=1)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Channel Shuffle\n",
    "\n",
    "Group Convolution的弊端在于不同组别的通道无法进行信息交流，堆积GConv层后一个问题是不同组之间的特征图是不通信的，这就好像分成了g个互不相干的道路，每一个人各走各的，**这可能会降低网络的特征提取能力**。这也是Xception，MobileNet等网络采用密集的1x1卷积（Dense Pointwise Convolution）的原因。\n",
    "\n",
    "为了解决不同组别通道“近亲繁殖”的问题，ShuffleNet优化了大量密集的1x1卷积（在使用的情况下计算量占用率达到了惊人的93.4%），引入Channel Shuffle机制（通道重排）。这项操作直观上表现为将不同分组通道**均匀分散重组**，使网络在下一层能处理不同组别通道的信息。\n",
    "\n",
    "![shufflenet3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/application/source_zh_cn/cv/images/shufflenet_3.png)\n",
    "\n",
    "如下图所示，对于g组，每组有n个通道的特征图，首先reshape成g行n列的矩阵，再将矩阵转置成n行g列，最后进行flatten操作，得到新的排列。这些操作都是可微分可导的且计算简单，在解决了信息交互的同时符合了ShuffleNet轻量级网络设计的轻量特征。\n",
    "\n",
    "![shufflenet4](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/application/source_zh_cn/cv/images/shufflenet_4.png)\n",
    "\n",
    "为了阅读方便，将Channel Shuffle的代码实现放在下方ShuffleNet模块的代码中。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ShuffleNet模块\n",
    "\n",
    "如下图所示，ShuffleNet对ResNet中的Bottleneck结构进行由(a)到(b), (c)的更改：\n",
    "\n",
    "1. 将开始和最后的$1\\times 1$卷积模块（降维、升维）改成Point Wise Group Convolution；\n",
    "\n",
    "2. 为了进行不同通道的信息交流，再降维之后进行Channel Shuffle；\n",
    "\n",
    "3. 降采样模块中，$3 \\times 3$ Depth Wise Convolution的步长设置为2，长宽降为原来的一般，因此shortcut中采用步长为2的$3\\times 3$平均池化，并把相加改成拼接。\n",
    "\n",
    "![shufflenet5](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/application/source_zh_cn/cv/images/shufflenet_5.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ShuffleV1Block(nn.Cell):\n",
    "    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):\n",
    "        super(ShuffleV1Block, self).__init__()\n",
    "        self.stride = stride\n",
    "        pad = ksize // 2\n",
    "        self.group = group\n",
    "        if stride == 2:\n",
    "            outputs = oup - inp\n",
    "        else:\n",
    "            outputs = oup\n",
    "        self.relu = nn.ReLU()\n",
    "        branch_main_1 = [\n",
    "            GroupConv(in_channels=inp, out_channels=mid_channels,\n",
    "                      kernel_size=1, stride=1, pad_mode=\"pad\", pad=0,\n",
    "                      groups=1 if first_group else group),\n",
    "            nn.BatchNorm2d(mid_channels),\n",
    "            nn.ReLU(),\n",
    "        ]\n",
    "        branch_main_2 = [\n",
    "            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,\n",
    "                      pad_mode='pad', padding=pad, group=mid_channels,\n",
    "                      weight_init='xavier_uniform', has_bias=False),\n",
    "            nn.BatchNorm2d(mid_channels),\n",
    "            GroupConv(in_channels=mid_channels, out_channels=outputs,\n",
    "                      kernel_size=1, stride=1, pad_mode=\"pad\", pad=0,\n",
    "                      groups=group),\n",
    "            nn.BatchNorm2d(outputs),\n",
    "        ]\n",
    "        self.branch_main_1 = nn.SequentialCell(branch_main_1)\n",
    "        self.branch_main_2 = nn.SequentialCell(branch_main_2)\n",
    "        if stride == 2:\n",
    "            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')\n",
    "\n",
    "    def construct(self, old_x):\n",
    "        left = old_x\n",
    "        right = old_x\n",
    "        out = old_x\n",
    "        right = self.branch_main_1(right)\n",
    "        if self.group > 1:\n",
    "            right = self.channel_shuffle(right)\n",
    "        right = self.branch_main_2(right)\n",
    "        if self.stride == 1:\n",
    "            out = self.relu(left + right)\n",
    "        elif self.stride == 2:\n",
    "            left = self.branch_proj(left)\n",
    "            out = ops.cat((left, right), 1)\n",
    "            out = self.relu(out)\n",
    "        return out\n",
    "\n",
    "    def channel_shuffle(self, x):\n",
    "        batchsize, num_channels, height, width = ops.shape(x)\n",
    "        group_channels = num_channels // self.group\n",
    "        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))\n",
    "        x = ops.transpose(x, (0, 2, 1, 3, 4))\n",
    "        x = ops.reshape(x, (batchsize, num_channels, height, width))\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构建ShuffleNet网络\n",
    "\n",
    "ShuffleNet网络结构如下图所示，以输入图像$224 \\times 224$，组数3（g = 3）为例，首先通过数量24，卷积核大小为$3 \\times 3$，stride为2的卷积层，输出特征图大小为$112 \\times 112$，channel为24；然后通过stride为2的最大池化层，输出特征图大小为$56 \\times 56$，channel数不变；再堆叠3个ShuffleNet模块（Stage2, Stage3, Stage4），三个模块分别重复4次、8次、4次，其中每个模块开始先经过一次下采样模块（上图(c)），使特征图长宽减半，channel翻倍（Stage2的下采样模块除外，将channel数从24变为240）；随后经过全局平均池化，输出大小为$1 \\times 1 \\times 960$，再经过全连接层和softmax，得到分类概率。\n",
    "\n",
    "![shufflenet6](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/tutorials/application/source_zh_cn/cv/images/shufflenet_6.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ShuffleNetV1(nn.Cell):\n",
    "    def __init__(self, n_class=1000, model_size='2.0x', group=3):\n",
    "        super(ShuffleNetV1, self).__init__()\n",
    "        print('model size is ', model_size)\n",
    "        self.stage_repeats = [4, 8, 4]\n",
    "        self.model_size = model_size\n",
    "        if group == 3:\n",
    "            if model_size == '0.5x':\n",
    "                self.stage_out_channels = [-1, 12, 120, 240, 480]\n",
    "            elif model_size == '1.0x':\n",
    "                self.stage_out_channels = [-1, 24, 240, 480, 960]\n",
    "            elif model_size == '1.5x':\n",
    "                self.stage_out_channels = [-1, 24, 360, 720, 1440]\n",
    "            elif model_size == '2.0x':\n",
    "                self.stage_out_channels = [-1, 48, 480, 960, 1920]\n",
    "            else:\n",
    "                raise NotImplementedError\n",
    "        elif group == 8:\n",
    "            if model_size == '0.5x':\n",
    "                self.stage_out_channels = [-1, 16, 192, 384, 768]\n",
    "            elif model_size == '1.0x':\n",
    "                self.stage_out_channels = [-1, 24, 384, 768, 1536]\n",
    "            elif model_size == '1.5x':\n",
    "                self.stage_out_channels = [-1, 24, 576, 1152, 2304]\n",
    "            elif model_size == '2.0x':\n",
    "                self.stage_out_channels = [-1, 48, 768, 1536, 3072]\n",
    "            else:\n",
    "                raise NotImplementedError\n",
    "        input_channel = self.stage_out_channels[1]\n",
    "        self.first_conv = nn.SequentialCell(\n",
    "            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),\n",
    "            nn.BatchNorm2d(input_channel),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')\n",
    "        features = []\n",
    "        for idxstage in range(len(self.stage_repeats)):\n",
    "            numrepeat = self.stage_repeats[idxstage]\n",
    "            output_channel = self.stage_out_channels[idxstage + 2]\n",
    "            for i in range(numrepeat):\n",
    "                stride = 2 if i == 0 else 1\n",
    "                first_group = idxstage == 0 and i == 0\n",
    "                features.append(ShuffleV1Block(input_channel, output_channel,\n",
    "                                               group=group, first_group=first_group,\n",
    "                                               mid_channels=output_channel // 4, ksize=3, stride=stride))\n",
    "                input_channel = output_channel\n",
    "        self.features = nn.SequentialCell(features)\n",
    "        self.globalpool = nn.AvgPool2d(7)\n",
    "        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.first_conv(x)\n",
    "        x = self.maxpool(x)\n",
    "        x = self.features(x)\n",
    "        x = self.globalpool(x)\n",
    "        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))\n",
    "        x = self.classifier(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 设置运行环境\n",
    "\n",
    "由于资源限制，需开启性能优化模式，具体设置如下参数：\n",
    "\n",
    " max_device_memory=\"2GB\" : 设置设备可用的最大内存为2GB。\n",
    "\n",
    " mode=mindspore.GRAPH_MODE : 表示在GRAPH_MODE模式中运行。\n",
    "\n",
    " device_target=\"Ascend\" : 表示待运行的目标设备为Ascend。\n",
    "\n",
    " jit_config={\"jit_level\":\"O2\"} : 编译优化级别开启极致性能优化，使用下沉的执行方式。\n",
    "\n",
    " ascend_config={\"precision_mode\":\"allow_mix_precision\"} : 自动混合精度，自动将部分算子的精度降低到float16或bfloat16。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "\n",
    "mindspore.set_context(max_device_memory=\"2GB\", mode=mindspore.GRAPH_MODE, device_target=\"Ascend\",  jit_config={\"jit_level\":\"O2\"}, ascend_config={\"precision_mode\":\"allow_mix_precision\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据集准备与加载\n",
    "\n",
    "采用CIFAR-10数据集对ShuffleNet进行预训练。CIFAR-10共有60000张32*32的彩色图像，均匀地分为10个类别，其中50000张图片作为训练集，10000图片作为测试集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz (162.2 MB)\n",
      "\n",
      "file_sizes: 100%|████████████████████████████| 170M/170M [00:41<00:00, 4.13MB/s]\n",
      "Extracting tar.gz file...\n",
      "Successfully downloaded / unzipped to ./dataset\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'./dataset'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from download import download\n",
    "\n",
    "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz\"\n",
    "\n",
    "download(url, \"./dataset\", kind=\"tar.gz\", replace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore as ms\n",
    "from mindspore.dataset import Cifar10Dataset\n",
    "from mindspore.dataset import vision, transforms\n",
    "\n",
    "def get_dataset(train_dataset_path, batch_size, usage):\n",
    "    image_trans = []\n",
    "    if usage == \"train\":\n",
    "        image_trans = [\n",
    "            vision.RandomCrop((32, 32), (4, 4, 4, 4)),\n",
    "            vision.RandomHorizontalFlip(prob=0.5),\n",
    "            vision.Resize((224, 224)),\n",
    "            vision.Rescale(1.0 / 255.0, 0.0),\n",
    "            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),\n",
    "            vision.HWC2CHW()\n",
    "        ]\n",
    "    elif usage == \"test\":\n",
    "        image_trans = [\n",
    "            vision.Resize((224, 224)),\n",
    "            vision.Rescale(1.0 / 255.0, 0.0),\n",
    "            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),\n",
    "            vision.HWC2CHW()\n",
    "        ]\n",
    "    label_trans = transforms.TypeCast(ms.int32)\n",
    "    dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)\n",
    "    dataset = dataset.map(image_trans, 'image')\n",
    "    dataset = dataset.map(label_trans, 'label')\n",
    "    dataset = dataset.batch(batch_size, drop_remainder=True)\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型预测\n",
    "\n",
    "在CIFAR-10的测试集上对模型进行预测，并将预测结果可视化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import mindspore.dataset as ds\n",
    "from mindspore import load_checkpoint, load_param_into_net\n",
    "from mindspore.train import Model\n",
    "\n",
    "# download ckpt\n",
    "shufflenet_url = \"https://modelers.cn/coderepo/web/v1/file/MindSpore-Lab/cluoud_obs/main/media/examples/mindspore-courses/orange-pi-online-infer/05-ShuffleNet/shufflenetv1_1-150_390.ckpt\"\n",
    "path = \"./shufflenetv1_1-150_390.ckpt\"\n",
    "ckpt_path = download(shufflenet_url, path, replace=True)\n",
    "\n",
    "net = ShuffleNetV1(model_size=\"2.0x\", n_class=10)\n",
    "show_lst = []\n",
    "param_dict = load_checkpoint(ckpt_path)\n",
    "load_param_into_net(net, param_dict)\n",
    "model = Model(net)\n",
    "dataset_predict = ds.Cifar10Dataset(dataset_dir=\"./dataset/cifar-10-batches-bin\", shuffle=False, usage=\"train\")\n",
    "dataset_show = ds.Cifar10Dataset(dataset_dir=\"./dataset/cifar-10-batches-bin\", shuffle=False, usage=\"train\")\n",
    "dataset_show = dataset_show.batch(16)\n",
    "show_images_lst = next(dataset_show.create_dict_iterator())[\"image\"].asnumpy()\n",
    "image_trans = [\n",
    "    vision.RandomCrop((32, 32), (4, 4, 4, 4)),\n",
    "    vision.RandomHorizontalFlip(prob=0.5),\n",
    "    vision.Resize((224, 224)),\n",
    "    vision.Rescale(1.0 / 255.0, 0.0),\n",
    "    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),\n",
    "    vision.HWC2CHW()\n",
    "        ]\n",
    "dataset_predict = dataset_predict.map(image_trans, 'image')\n",
    "dataset_predict = dataset_predict.batch(16)\n",
    "class_dict = {0:\"airplane\", 1:\"automobile\", 2:\"bird\", 3:\"cat\", 4:\"deer\", 5:\"dog\", 6:\"frog\", 7:\"horse\", 8:\"ship\", 9:\"truck\"}\n",
    "# 推理效果展示(上方为预测的结果，下方为推理效果图片)\n",
    "plt.figure(figsize=(16, 5))\n",
    "predict_data = next(dataset_predict.create_dict_iterator())\n",
    "output = model.predict(ms.Tensor(predict_data['image']))\n",
    "pred = np.argmax(output.asnumpy(), axis=1)\n",
    "index = 0\n",
    "for image in show_images_lst:\n",
    "    plt.subplot(2, 8, index+1)\n",
    "    plt.title('{}'.format(class_dict[pred[index]]))\n",
    "    index += 1\n",
    "    plt.imshow(image)\n",
    "    plt.axis(\"off\")\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "5d834606fceea8447e1f2d2b8feecce2028f1eecd4b0eabaa8daf1aeed30752e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
